
Commit 680970e

review: type hint fixes

fix

1 parent 420febf · commit 680970e

File tree

2 files changed: +74 −69 lines

nemoguardrails/rails/llm/llmrails.py

Lines changed: 72 additions & 68 deletions
@@ -923,6 +923,19 @@ async def generate_async(
             The completion (when a prompt is provided) or the next message.

         System messages are not yet supported."""
+        # convert options to gen_options of type GenerationOptions
+        gen_options: Optional[GenerationOptions] = None
+
+        if prompt is None and messages is None:
+            raise ValueError("Either prompt or messages must be provided.")
+
+        if prompt is not None and messages is not None:
+            raise ValueError("Only one of prompt or messages can be provided.")
+
+        if prompt is not None:
+            # Currently, we transform the prompt request into a single turn conversation
+            messages = [{"role": "user", "content": prompt}]
+
         # If a state object is specified, then we switch to "generation options" mode.
         # This is because we want the output to be a GenerationResponse which will contain
         # the output state.
@@ -932,15 +945,25 @@ async def generate_async(
                 state = json_to_state(state["state"])

             if options is None:
-                options = GenerationOptions()
-
-        # We allow options to be specified both as a dict and as an object.
-        if options and isinstance(options, dict):
-            options = GenerationOptions(**options)
+                gen_options = GenerationOptions()
+            elif isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            else:
+                gen_options = options
+        else:
+            # We allow options to be specified both as a dict and as an object.
+            if options and isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            elif isinstance(options, GenerationOptions):
+                gen_options = options
+            elif options is None:
+                gen_options = None
+            else:
+                raise TypeError("options must be a dict or GenerationOptions")

         # Save the generation options in the current async context.
-        # At this point, options is either None or GenerationOptions
-        generation_options_var.set(options if not isinstance(options, dict) else None)
+        # At this point, gen_options is either None or GenerationOptions
+        generation_options_var.set(gen_options)

         if streaming_handler:
             streaming_handler_var.set(streaming_handler)
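
For reference, the normalization the two hunks above introduce reduces to the following standalone sketch. The GenerationOptions stand-in and the normalize_options helper are hypothetical (simplified from the library's pydantic model), not actual nemoguardrails API:

    from typing import Optional, Union

    from pydantic import BaseModel


    class GenerationOptions(BaseModel):
        # Hypothetical stand-in for nemoguardrails.rails.llm.options.GenerationOptions.
        llm_output: bool = False


    def normalize_options(
        options: Optional[Union[dict, GenerationOptions]],
        state_provided: bool,
    ) -> Optional[GenerationOptions]:
        """Mirrors the diff's options -> gen_options conversion."""
        if state_provided:
            # "Generation options" mode: always end up with a GenerationOptions.
            if options is None:
                return GenerationOptions()
            if isinstance(options, dict):
                return GenerationOptions(**options)
            return options
        # Otherwise options may stay None; dicts are still upgraded.
        if isinstance(options, dict):
            return GenerationOptions(**options)
        if isinstance(options, GenerationOptions) or options is None:
            return options
        raise TypeError("options must be a dict or GenerationOptions")

The payoff is that downstream code sees a single, well-typed gen_options that is either None or a GenerationOptions instance, instead of the old dict-or-object-or-None union.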
@@ -950,23 +973,14 @@ async def generate_async(
         # requests are made.
         self.explain_info = self._ensure_explain_info()

-        if prompt is not None:
-            # Currently, we transform the prompt request into a single turn conversation
-            messages = [{"role": "user", "content": prompt}]
-            raw_llm_request.set(prompt)
-        else:
-            raw_llm_request.set(messages)
+        raw_llm_request.set(messages)

         # If we have generation options, we also add them to the context
-        if options:
+        if gen_options:
             messages = [
                 {
                     "role": "context",
-                    "content": {
-                        "generation_options": getattr(
-                            options, "dict", lambda: options
-                        )()
-                    },
+                    "content": {"generation_options": gen_options.model_dump()},
                 }
             ] + (messages or [])

@@ -976,9 +990,8 @@ async def generate_async(
         if (
             messages
             and messages[-1]["role"] == "assistant"
-            and options
-            and hasattr(options, "rails")
-            and getattr(getattr(options, "rails", None), "dialog", None) is False
+            and gen_options
+            and gen_options.rails.dialog is False
         ):
             # We already have the first message with a context update, so we use that
             messages[0]["content"]["bot_message"] = messages[-1]["content"]
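
The context-message construction above can be illustrated in isolation. A rough sketch, reusing the hypothetical GenerationOptions stand-in from the earlier snippet and pydantic v2's model_dump():

    # Sketch: prepend a "context" message carrying the serialized options.
    gen_options = GenerationOptions(llm_output=True)
    messages = [{"role": "user", "content": "Hi there!"}]

    messages = [
        {
            "role": "context",
            "content": {"generation_options": gen_options.model_dump()},
        }
    ] + (messages or [])

    # messages[0]["content"]["generation_options"] is now a plain dict,
    # e.g. {"llm_output": True}, instead of the result of the old
    # getattr(options, "dict", ...) fallback.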
@@ -995,7 +1008,7 @@ async def generate_async(
         processing_log = []

         # The array of events corresponding to the provided sequence of messages.
-        events = self._get_events_for_messages(messages or [], state)
+        events = self._get_events_for_messages(messages, state)  # type: ignore

         if self.config.colang_version == "1.0":
             # If we had a state object, we also need to prepend the events from the state.
@@ -1114,7 +1127,7 @@ async def generate_async(
             # If a state object is not used, then we use the implicit caching
             if state is None:
                 # Save the new events in the history and update the cache
-                cache_key = get_history_cache_key((messages or []) + [new_message])
+                cache_key = get_history_cache_key((messages) + [new_message])  # type: ignore
                 self.events_history_cache[cache_key] = events
             else:
                 output_state = {"events": events}
@@ -1142,33 +1155,29 @@ async def generate_async(
         # IF tracing is enabled we need to set GenerationLog attrs
         original_log_options = None
         if self.config.tracing.enabled:
-            if options is None:
-                options = GenerationOptions()
+            if gen_options is None:
+                gen_options = GenerationOptions()
             else:
-                # create a copy of the options to avoid modifying the original
-                if isinstance(options, GenerationOptions):
-                    options = options.model_copy(deep=True)
-                else:
-                    # If options is a dict, convert it to GenerationOptions
-                    options = GenerationOptions(**options)
-                original_log_options = options.log.model_copy(deep=True)
+                # create a copy of the gen_options to avoid modifying the original
+                gen_options = gen_options.model_copy(deep=True)
+                original_log_options = gen_options.log.model_copy(deep=True)

             # enable log options
             # it is aggressive, but these are required for tracing
             if (
-                not options.log.activated_rails
-                or not options.log.llm_calls
-                or not options.log.internal_events
+                not gen_options.log.activated_rails
+                or not gen_options.log.llm_calls
+                or not gen_options.log.internal_events
             ):
-                options.log.activated_rails = True
-                options.log.llm_calls = True
-                options.log.internal_events = True
+                gen_options.log.activated_rails = True
+                gen_options.log.llm_calls = True
+                gen_options.log.internal_events = True

         tool_calls = extract_tool_calls_from_events(new_events)
         llm_metadata = get_and_clear_response_metadata_contextvar()

         # If we have generation options, we prepare a GenerationResponse instance.
-        if options:
+        if gen_options:
             # If a prompt was used, we only need to return the content of the message.
             if prompt:
                 res = GenerationResponse(response=new_message["content"])
@@ -1195,9 +1204,9 @@ async def generate_async(

             if self.config.colang_version == "1.0":
                 # If output variables are specified, we extract their values
-                if getattr(options, "output_vars", None):
+                if gen_options and gen_options.output_vars:
                     context = compute_context(events)
-                    output_vars = getattr(options, "output_vars", None)
+                    output_vars = gen_options.output_vars
                     if isinstance(output_vars, list):
                         # If we have only a selection of keys, we filter to only that.
                         res.output_data = {k: context.get(k) for k in output_vars}
@@ -1208,65 +1217,64 @@ async def generate_async(
                 _log = compute_generation_log(processing_log)

                 # Include information about activated rails and LLM calls if requested
-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                 if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
+                    log_options.activated_rails or log_options.llm_calls
                 ):
                     res.log = GenerationLog()

                     # We always include the stats
                     res.log.stats = _log.stats

-                    if getattr(log_options, "activated_rails", False):
+                    if log_options.activated_rails:
                         res.log.activated_rails = _log.activated_rails

-                    if getattr(log_options, "llm_calls", False):
+                    if log_options.llm_calls:
                         res.log.llm_calls = []
                         for activated_rail in _log.activated_rails:
                             for executed_action in activated_rail.executed_actions:
                                 res.log.llm_calls.extend(executed_action.llm_calls)

                 # Include internal events if requested
-                if getattr(log_options, "internal_events", False):
+                if log_options and log_options.internal_events:
                     if res.log is None:
                         res.log = GenerationLog()

                     res.log.internal_events = new_events

                 # Include the Colang history if requested
-                if getattr(log_options, "colang_history", False):
+                if log_options and log_options.colang_history:
                     if res.log is None:
                         res.log = GenerationLog()

                     res.log.colang_history = get_colang_history(events)

                 # Include the raw llm output if requested
-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                     # Currently, we include the output from the generation LLM calls.
                     for activated_rail in _log.activated_rails:
                         if activated_rail.type == "generation":
                             for executed_action in activated_rail.executed_actions:
                                 for llm_call in executed_action.llm_calls:
                                     res.llm_output = llm_call.raw_response
             else:
-                if getattr(options, "output_vars", None):
+                if gen_options and gen_options.output_vars:
                     raise ValueError(
                         "The `output_vars` option is not supported for Colang 2.0 configurations."
                     )

-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                 if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
-                    or getattr(log_options, "internal_events", False)
-                    or getattr(log_options, "colang_history", False)
+                    log_options.activated_rails
+                    or log_options.llm_calls
+                    or log_options.internal_events
+                    or log_options.colang_history
                 ):
                     raise ValueError(
                         "The `log` option is not supported for Colang 2.0 configurations."
                     )

-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                     raise ValueError(
                         "The `llm_output` option is not supported for Colang 2.0 configurations."
                     )
@@ -1300,25 +1308,21 @@ async def generate_async(
             if original_log_options:
                 if not any(
                     (
-                        getattr(original_log_options, "internal_events", False),
-                        getattr(original_log_options, "activated_rails", False),
-                        getattr(original_log_options, "llm_calls", False),
-                        getattr(original_log_options, "colang_history", False),
+                        original_log_options.internal_events,
+                        original_log_options.activated_rails,
+                        original_log_options.llm_calls,
+                        original_log_options.colang_history,
                     )
                 ):
                     res.log = None
                 else:
                     # Ensure res.log exists before setting attributes
                     if res.log is not None:
-                        if not getattr(
-                            original_log_options, "internal_events", False
-                        ):
+                        if not original_log_options.internal_events:
                             res.log.internal_events = []
-                        if not getattr(
-                            original_log_options, "activated_rails", False
-                        ):
+                        if not original_log_options.activated_rails:
                             res.log.activated_rails = []
-                        if not getattr(original_log_options, "llm_calls", False):
+                        if not original_log_options.llm_calls:
                             res.log.llm_calls = []

             return res
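
The tracing branch, and the final restore of the caller's log options, hinge on pydantic v2's model_copy(deep=True). A minimal sketch of that copy-then-mutate pattern, with hypothetical stand-in models rather than the library's own classes:

    from pydantic import BaseModel


    class LogOptions(BaseModel):
        activated_rails: bool = False
        llm_calls: bool = False
        internal_events: bool = False
        colang_history: bool = False


    class Options(BaseModel):
        log: LogOptions = LogOptions()


    caller_options = Options()

    # Deep-copy before force-enabling the log flags tracing needs ...
    working = caller_options.model_copy(deep=True)
    original_log_options = working.log.model_copy(deep=True)
    working.log.activated_rails = True
    working.log.llm_calls = True
    working.log.internal_events = True

    # ... so the caller's object is untouched and the original flags remain
    # available for trimming the response log back down afterwards.
    assert caller_options.log.activated_rails is False
    assert original_log_options.llm_calls is False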

nemoguardrails/rails/llm/options.py

Lines changed: 2 additions & 1 deletion
@@ -76,6 +76,7 @@
 # {..., log: {"llm_calls": [...]}}

 """
+
 from typing import Any, Dict, List, Optional, Union

 from pydantic import BaseModel, Field, root_validator
@@ -156,7 +157,7 @@ class GenerationOptions(BaseModel):
         default=None,
         description="Additional parameters that should be used for the LLM call",
     )
-    llm_output: Optional[bool] = Field(
+    llm_output: bool = Field(
         default=False,
         description="Whether the response should also include any custom LLM output.",
     )
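
Why narrowing Optional[bool] to bool matters for the call sites in llmrails.py, sketched with hypothetical Before/After models:

    from typing import Optional

    from pydantic import BaseModel, Field


    class Before(BaseModel):
        llm_output: Optional[bool] = Field(default=False)


    class After(BaseModel):
        llm_output: bool = Field(default=False)


    Before(llm_output=None)   # allowed: the field may be None
    # After(llm_output=None)  # would now raise a ValidationError

    # With the narrowed type, `if gen_options.llm_output:` type-checks as a
    # plain bool, so no None-guard is needed where the option is read.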
